import os
from re import L
from turtle import forward
import torch
from torch import nn
import numpy as np
import math
import torch.nn.init as init
from torch.nn import functional as F
from torch.autograd import Variable
from pdb import set_trace as stop
from transformers import BertModel,AutoModelForMaskedLM,AutoTokenizer
import warnings
warnings.filterwarnings("ignore")


class EmoCos(nn.Module):
    def __init__(self, opt, emo2idx, cos2idx) -> None:
        """Create the BERT encoder, tokenizer and the task-specific heads.

        Args:
            opt: Run configuration object. This constructor reads
                ``opt.bert_path`` (used for both the encoder and the
                tokenizer), ``opt.task`` and, for the ``'CF'`` task only,
                ``opt.hidden_size`` (embedding width).
            emo2idx: Sized container (presumably an emotion label -> index
                mapping); ``len(emo2idx)`` is the number of rows of the
                emotion embedding built for the ``'CF'`` task.
            cos2idx: Stored on the instance; not otherwise used here
                (presumably a consequence label -> index mapping).

        Note:
            Every ``nn.Sequential()`` below is created without layers, so it
            holds no parameters and returns its input unchanged until layers
            are added to it.
        """
        super().__init__()
        self.opt = opt
        self.emo2idx = emo2idx
        self.cos2idx = cos2idx
        self.bert = BertModel.from_pretrained(opt.bert_path)
        self.tokenizer = AutoTokenizer.from_pretrained(opt.bert_path)
        self.hsize = 500

        # Fuses the concatenated multimodal features in ``forward``.
        self.simple_fusion = nn.Sequential()

        if self.opt.task == 'CF':
            # In this task ``forward`` looks up an embedding for each
            # emotion id passed in as ``Emotion_input_ids``.
            self.emo_embedding = nn.Embedding(len(emo2idx), opt.hidden_size)

            self.fea_emo = nn.Sequential()
            self.fea_utt = nn.Sequential()
            self.cos_MLP = nn.Sequential()

        if self.opt.task in ('ECPF', 'ECPF-C'):
            self.fea_emo = nn.Sequential()
            self.fea_utt = nn.Sequential()
            self.emo_MLP = nn.Sequential()
            self.cos_MLP = nn.Sequential()
            # ``forward`` calls ``self.fea_emo_impact`` on the 'TAVC' path of
            # these tasks; without this attribute that path raises
            # AttributeError. Registered last so the order of the
            # previously existing sub-modules is unchanged.
            self.fea_emo_impact = nn.Sequential()
         
    def forward(self,Speaker_input_ids, Speaker_input_attention_mask, input_ids, input_attention_mask,\
                describe_input_ids, describe_attention_masks, describe_token_type_ids,\
                emo_clue_input_ids, emo_clue_attention_mask, emo_clue_token_type_ids,\
                cos_clue_why_input_ids, cos_clue_why_attention_mask, cos_clue_why_token_type_ids,\
                cos_clue_impact_input_ids, cos_clue_impact_attention_mask, cos_clue_impact_token_type_ids,\
                Emotion_input_ids, Emotion_input_attention_mask, Cosequence_input_ids,\
                Cosequence_input_attention_mask, video_feature, video_feature_attention_mask,\
                audio_feature, audio_feature_attention_mask,spk_and_input_ids,\
                spk_and_input_ids_attention_mask,spk_and_input_token_type_ids,\
                crition_foremo,crition_forcos, opt, epoch) -> None:
        """Compute the training loss and masked predictions for the configured task.

        The task is selected by ``self.opt.task`` (``'CF'``, ``'ECPF'`` or
        ``'ECPF-C'``); whether the clue encodings are used is selected by
        ``opt.select_modality == 'TAVC'`` on the ``opt`` *argument*.

        Args:
            spk_and_input_ids, spk_and_input_ids_attention_mask,
            spk_and_input_token_type_ids: Passed to ``self.get_pooler_output``.
                ``spk_and_input_ids`` must be 3-D: its second dimension is
                unpacked as ``seq_len``.
            emo_clue_*, cos_clue_why_*, cos_clue_impact_*: Clue inputs, encoded
                with ``self.get_pooler_output`` only on the 'TAVC' path (the
                emotion clue additionally only when the task is not 'CF').
            Emotion_input_ids: Emotion ids. Embedded in the 'CF' task; used as
                labels in the 'ECPF' tasks, where it is MODIFIED IN PLACE.
            Emotion_input_attention_mask: Mask for the emotion labels
                ('ECPF' tasks only).
            Cosequence_input_ids: Consequence labels; MODIFIED IN PLACE in
                every task branch.
            Cosequence_input_attention_mask: Mask multiplied into the
                consequence masks.
            video_feature, audio_feature: Passed to ``self.get_modality_cat``
                together with the text pooler output.
            crition_foremo, crition_forcos: Loss callables invoked as
                ``criterion(flattened_scores, flattened_labels)``.
            opt: Only ``opt.select_modality`` is read from this argument.
            Speaker_input_ids, Speaker_input_attention_mask, input_ids,
            input_attention_mask, describe_input_ids, describe_attention_masks,
            describe_token_type_ids, video_feature_attention_mask,
            audio_feature_attention_mask, epoch: Not used by this method.

        Returns:
            A 5-tuple ``(loss, pred_emos, Emotion_input_ids, pred_cos,
            Cosequence_input_ids)``. In the 'CF' task the second element is
            ``Emotion_input_ids`` itself (no emotion prediction is made).
            If ``self.opt.task`` matches none of the three tasks, no branch
            runs and the method implicitly returns ``None``.

        NOTE(review): the ``-> None`` annotation does not match the tuple
        that is actually returned.
        NOTE(review): ``emo_used_mask_reapeated`` is read in both task
        branches but is never assigned in this method; unless it is defined
        at module level outside the visible source, those lines raise
        ``NameError``. ``self.get_cos_score()`` is also called without
        arguments while the features computed just before it go unused -
        this looks like code was removed; confirm against the original.
        NOTE(review): ``self.get_pooler_output``, ``self.get_modality_cat``,
        ``self.rnn_encoding`` and ``self.get_cos_score`` are not defined in
        the visible part of the class.
        """
        # Encode the speaker+utterance text.
        pooler_output = self.get_pooler_output(spk_and_input_ids,spk_and_input_ids_attention_mask,spk_and_input_token_type_ids)

        if opt.select_modality =='TAVC':
            # Encode the clue texts. The emotion clue is skipped for 'CF'.
            if self.opt.task != 'CF':
                emo_clue_pooler_output= self.get_pooler_output(emo_clue_input_ids,emo_clue_attention_mask,emo_clue_token_type_ids)
            cos_clue_why_pooler_output= self.get_pooler_output(cos_clue_why_input_ids,cos_clue_why_attention_mask,cos_clue_why_token_type_ids)
            cos_clue_impact_pooler_output= self.get_pooler_output(cos_clue_impact_input_ids,cos_clue_impact_attention_mask,cos_clue_impact_token_type_ids)
            # End of clue encoding.

        # Only seq_len is used below; batch_size is unpacked but unused.
        batch_size,seq_len,_ =spk_and_input_ids.shape
        # Combine text + video + audio features, fuse them, then run the
        # sequence encoder (only its first return value is kept).
        mm_feature_cat = self.get_modality_cat(pooler_output,video_feature,audio_feature)

        mm_feature = self.simple_fusion(mm_feature_cat)
        lstm_mm_feature,*_ = self.rnn_encoding(mm_feature)



        if self.opt.task == 'CF':
            # Emotion ids are given as input in this task and are embedded.
            emo_emb = self.emo_embedding(Emotion_input_ids)
            # Emotion-side features; the CF task uses no emotion clue.
            if opt.select_modality !='TAVC':
                emo_feature = self.fea_emo(torch.cat((lstm_mm_feature,emo_emb),dim=2))
                utt_feature = self.fea_utt(lstm_mm_feature)
            else:
                # With clues: the "impact" clue joins the emotion side and
                # the "why" clue joins the utterance side.
                emo_feature = self.fea_emo(torch.cat((lstm_mm_feature,emo_emb,cos_clue_impact_pooler_output),dim=2))
                utt_feature = self.fea_utt(torch.cat((lstm_mm_feature,cos_clue_why_pooler_output),dim=2))



            # Consequence scoring.
            # NOTE(review): emo_feature / utt_feature are not passed to
            # get_cos_score() and are unused in the visible code.
            pred_cos,cos_score = self.get_cos_score()
            # NOTE(review): emo_used_mask_reapeated is not assigned anywhere
            # in this method.
            mask = Cosequence_input_attention_mask*emo_used_mask_reapeated
            # In-place write on the caller's tensor: masked-out labels become
            # -1 (presumably the criterion's ignore_index - confirm).
            Cosequence_input_ids[mask==0]=-1
            # Flatten scores to (N, last_dim) and labels to (N,) for the loss.
            dense_cos_score = cos_score.reshape(-1,cos_score.shape[-1])
            dense_Cosequence_input_ids=Cosequence_input_ids.reshape(-1)

            loss = crition_forcos(dense_cos_score,dense_Cosequence_input_ids)

            # Keep predictions only strictly above the main diagonal of the
            # last two dims and where the emotion mask is non-zero; every
            # other prediction is overwritten with -1.
            ones = torch.ones_like(pred_cos)
            cos_used_mask4 = torch.triu(ones,diagonal=1)
            mask4 = emo_used_mask_reapeated*cos_used_mask4
            pred_cos[mask4==0]=-1

            # Emotion_input_ids fills both the "predicted" and the "gold"
            # emotion slots of the returned tuple.
            return loss,Emotion_input_ids,Emotion_input_ids,pred_cos,Cosequence_input_ids






        if self.opt.task =='ECPF' or self.opt.task =='ECPF-C':
            # Same structure as the CF branch, but emotions are predicted
            # instead of being embedded from the input.
            if opt.select_modality !='TAVC':
                emo_feature = self.fea_emo(lstm_mm_feature)
                utt_feature = self.fea_utt(lstm_mm_feature)
            else:
                emo_feature = self.fea_emo(torch.cat((lstm_mm_feature,emo_clue_pooler_output),dim=2))
                # NOTE(review): self.fea_emo_impact is not created by the
                # visible __init__.
                emo_feature_impact = self.fea_emo_impact(torch.cat((lstm_mm_feature,cos_clue_impact_pooler_output),dim=2))
                utt_feature = self.fea_utt(torch.cat((lstm_mm_feature,cos_clue_why_pooler_output),dim=2))


            # Emotion classification loss.
            emo_scores = self.emo_MLP(emo_feature)
            mask1 = Emotion_input_attention_mask
            # In-place write on the caller's tensor: masked-out emotion
            # labels become -1 (presumably ignore_index - confirm).
            Emotion_input_ids[mask1==0]=-1
            dense_emo_scores = emo_scores.reshape(-1,emo_scores.shape[-1])
            dense_Emotion_input_ids=Emotion_input_ids.reshape(-1)
            loss_emo = crition_foremo(dense_emo_scores,dense_Emotion_input_ids)
            pred_emos = emo_scores.argmax(dim=2)

            # Masked-out positions are reported as class 0.
            pred_emos[mask1==0]=0

            # Consequence scoring: copy the emotion-side feature seq_len
            # times along a new dim 2.
            # NOTE(review): emo_feature_repeated is unused in the visible code.
            if opt.select_modality !='TAVC':
                emo_feature_repeated = emo_feature.unsqueeze(2).repeat(1,1,seq_len,1)
            else:
                emo_feature_repeated = emo_feature_impact.unsqueeze(2).repeat(1,1,seq_len,1)

            pred_cos,cos_score = self.get_cos_score()

            # NOTE(review): emo_used_mask_reapeated is not assigned anywhere
            # in this method.
            mask2 = Cosequence_input_attention_mask*emo_used_mask_reapeated
            # The loss labels are masked on a clone, so the caller's tensor
            # is not touched by this step.
            cos_mask_input_ids = Cosequence_input_ids.clone()
            cos_mask_input_ids[mask2==0]=-1
            dense_cos_score = cos_score.reshape(-1,cos_score.shape[-1])
            dense_Cosequence_input_ids=cos_mask_input_ids.reshape(-1)

            loss_cos = crition_forcos(dense_cos_score,dense_Cosequence_input_ids)
            # Total loss is the mean of the emotion and consequence losses.
            loss = (loss_emo+loss_cos)/2

            # Labels returned to the caller: keep consequence labels only
            # where the gold emotion id is > 0 (ids 0 and the -1 written
            # above are excluded). This writes into the caller's tensor.
            emo_used_mask3 = (torch.gt(Emotion_input_ids,0)*1).unsqueeze(2)
            emo_used_mask3_reapeated = emo_used_mask3.repeat(1,1,seq_len)
            mask3 = Cosequence_input_attention_mask*emo_used_mask3_reapeated
            Cosequence_input_ids[mask3==0]=-1


            # Keep predictions only strictly above the main diagonal of the
            # last two dims and where the emotion mask is non-zero; every
            # other prediction is overwritten with -1.
            ones = torch.ones_like(pred_cos)
            cos_used_mask4 = torch.triu(ones,diagonal=1)
            mask4 = emo_used_mask_reapeated*cos_used_mask4
            pred_cos[mask4==0]=-1

            # Disabled debug aid: break into pdb when the loss is NaN.
            # if math.isnan(loss):
            #     stop()
            return loss,pred_emos,Emotion_input_ids,pred_cos,Cosequence_input_ids
        


